/*
 * Copyright (C) 2018-2020 Stefan Westerfeld
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "utils.hh"
#include "shortcode.hh"
#include "wmcommon.hh"

#include <assert.h>

using std::vector;

/* Codes from codetables.de / magma online calculator BKLC (GF(2), N, K) */

/*
 * Generator matrix installed by short_code_init (20): 20 rows (one per
 * message bit) of 65 entries (one per coded bit).  Entries are 0/1; the
 * encoder XORs together the rows selected by the set message bits.
 * NOTE(review): the trailing "20" in the name is presumably the minimum
 * distance of the code -- confirm against the code table source.
 */
static vector<vector<int>> block_65_20_20 = {
  {1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,1,1,1,1,1,1,1,0,1,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,1,1,1,0,0,1,1,1,0,1,1,1},
  {0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,0,1,0,0,0,1,1,0,1,0,0,0,0,0,1,1,1,1,0,1,1,0,0,1,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1},
  {0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,1,0,1,1,0,1,0,0,0,0,1,1,1,1,0,0,1,0,1,1,0,0,1,1,1,0,1,0,1,0,1,1,0,1,1},
  {0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,1,1,0,1,0,1,1,0,1,0,0,0,1,1,1,1,0,0,0,0,1,0,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1},
  {0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,1,0,1,1,0,1,0,0,1,1,1,1,0,0,0,1,1,0,1,0,1,0,1,0,1,1,0,0,1,0,0,0,0},
  {0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,1,0,0,0,1,1,0,1,1,0,1,0,0,0,0,1,1,1,0,0,0,1,0,1,1,0,0,1,1,1,0,1,0,1,0,1,1},
  {0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,1,1,0,0,1,1,0,1,0,1,1,0,1,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,1,0,0,1,0,0,0,1,0,1},
  {0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,1,1,0,1,0,0,1,0,0,0,0,0,0,1,1,0,1,0,1,0,1,0,1,1,0,0,1,0},
  {0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,1,1,0,1,1,1,1,1,0,0,1,0,1,1,1,0,1,1,1,1,1,1,0,1,0,0,0,1,0,1,1,0,0,1,0,1,1,1,0},
  {0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,1,0,1,1,1,0,1,1,0,0,1,0,1,1,0,1,1,1,0,1,1,0,0,1,0,1,1,1,0,0,1,1,1,1,0,1,0,0},
  {0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,1,1,0,1,0,0,1,1,1,1,0,0,1,0,0,0,1,0,1,0,1,0,1,1,0,0,1,0,0,0,0,0,0,0,0,1,1,0,1},
  {0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0,1,0,1,1,0,1,0,0,0,1,1,1,1,0,0,1,1,0,0,1,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,0,0,0,1,0},
  {0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,0,1,0,0,1,0,1,1,0,1,0,1,1,1,1,0,0,1,0,1,0,1,1,1,0,1,1,0,0,1,0,0,1,1,0,0,1,0,0,0,1,0},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,1,1,0,1,0,0,0,0,0,1,1,1,1,0,1,0,1,0,1,0,0,0,0,0,1,1,1,0,1,1,1,1,1,1,0,0,1,1,0},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,1,1,0,1,0,1,1,0,1,0,0,0,1,1,1,1,0,0,1,1,1,0,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,0,0},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,1,0,1,0,0,0,0,0,0,0,1,1,1,0,1,1,0,1,0,1,1,0,0,0,1,0,1,1,1,1,0,0,1,1,1,1,1,1},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,1,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,1,1,0,0,1,1,0,1,0,0,0,1,1,0,0,1,0,1,1,0,0,0,1,1,0,1,1},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,1,1,1,0,1,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,1,1,1,1,0,1,1,1,0,0,1,0,0,0,0,1,0,0,1},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,1,1,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,1,1,0,1,1,0,0,1,0,0,0,1,1,0,0,1,0,1,1,0,0,0,1,1},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,0,1,1,1,0,1,1,1,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,1,1,1,1,0,1,1,1,0,0,1,0,0,0,0,1},
};

/*
 * Generator matrix installed by short_code_init (16): 16 rows (one per
 * message bit) of 61 entries (one per coded bit).  Entries are 0/1; the
 * encoder XORs together the rows selected by the set message bits.
 * NOTE(review): the trailing "21" in the name is presumably the minimum
 * distance of the code -- confirm against the code table source.
 */
static vector<vector<int>> block_61_16_21 = {
  {1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,1,0,1,1,1,0,1,0,1,1,1,1,1,1,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,1,0,0},
  {0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,1,0,1,1,1,0,1,0,1,1,1,1,1,1,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,1,0},
  {0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,1,0,1,1,1,0,1,0,1,1,1,1,1,1,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,1,0,0,1},
  {0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,0,1,1,0,1,0,0,1,0,1,0,0,0,1,0,1,1,1,1,1,0,1,0,0,1,1,1,1,1,1,0,0,1,1,1,0,0,1},
  {0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,1,0,1,1,1,1,1,0,0,1,1,0,1,0,1,1,0,0,0,0,0,1,0,1,1,1,1,1,1,0,0,0,0,1,1,0,0,0,0,0,1},
  {0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,1,0,1,0,1,1,1,1,1,0,0,1,0,1,0,1,1,0,1,1,0,0,0,0,1,0,0,1,0,1,1,1,0,0,1,1,1,1,0,1},
  {0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,1,0,0,0,1,1,0,0,0,1,1,0,1,0,0,0,1,1,0,1,1,0,1,0,0,1,0,0,1,0,0,1,1,0,0,0,0,1,1},
  {0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,1,0,1,0,0,1,1,0,0,1,0,1,1,1,0,0,1,0,0,1,0,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0,0},
  {0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,1,0,1,0,0,1,1,0,0,1,0,1,1,1,0,0,1,0,0,1,0,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,0,1,1,1,1,0},
  {0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,1,1,0,1,0,0,1,1,0,0,1,0,1,1,1,0,0,1,0,0,1,0,0,0,0,0,1,1,1,1,1,1,1,1,0,1,0,0,1,1,1,1},
  {0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,1,1,1,1,0,1,0,0,0,1,0,0,0,1,0,0,1,0,1,0,0,0,1,0,0,0,0,1,1,1,1,1,0,1,0},
  {0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,1,1,1,1,0,1,0,0,0,1,0,0,0,1,0,0,1,0,1,0,0,0,1,0,0,0,0,1,1,1,1,1,0,1},
  {0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,1,1,1,0,0,1,1,1,1,0,0,0,0,0,1,0,1,1,0,0,0,1,1,1,1,0,0,0,0,1,0,1,0,1,1,0,1,1,0,0,0,1,1},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,1,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,1,0,1,0,1,1,0,1,1,1,1,0,1,0,1,1,1,0,1,1,0,0},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,1,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,1,0,1,0,1,1,0,1,1,1,1,0,1,0,1,1,1,0,1,1,0},
  {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,1,0,0,0,1,1,1,1,1,1,1,1,0,1,0,1,1,0,1,0,1,0,1,1,0,1,1,1,1,0,1,0,1,1,1,0,1,1},
};

/*
 * Generator matrix installed by short_code_init (12): 12 rows (one per
 * message bit) of 56 entries (one per coded bit).  Entries are 0/1; the
 * encoder XORs together the rows selected by the set message bits.
 * NOTE(review): the trailing "22" in the name is presumably the minimum
 * distance of the code -- confirm against the code table source.
 */
static vector<vector<int>> block_56_12_22 = {
  { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1 },
  { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0 },
  { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0 },
  { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0 },
  { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1 },
  { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0 },
  { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1 },
  { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0 },
  { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1 },
  { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0 },
  { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1 },
  { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1 },
};

/* Currently selected generator matrix; short_code_init() copies one of the
 * block_* tables into it and sets the two counts below at the same time. */
static vector<vector<int>> gen_matrix;
/* number of message bits per block (row count of gen_matrix) */
static size_t              gen_in_count = 0;
/* number of coded bits per block (column count of gen_matrix) */
static size_t              gen_out_count = 0;

/*
 * Select the block code used for messages of k bits.
 *
 * Supported values of k are 12, 16 and 20.  On success the matching generator
 * matrix and its dimensions are installed in gen_matrix / gen_in_count /
 * gen_out_count and the number of coded bits per block is returned.  For any
 * other k the current selection is left untouched and 0 is returned.
 */
size_t
short_code_init (size_t k)
{
  switch (k)
    {
    case 12:
      gen_in_count  = 12;
      gen_out_count = 56;
      gen_matrix    = block_56_12_22;
      break;

    case 16:
      gen_in_count  = 16;
      gen_out_count = 61;
      gen_matrix    = block_61_16_21;
      break;

    case 20:
      gen_in_count  = 20;
      gen_out_count = 65;
      gen_matrix    = block_65_20_20;
      break;

    default: /* unsupported k */
      return 0;
    }

  return gen_out_count;
}

/*
 * Encode in_bits for the given block type: dispatches to short_encode() when
 * Params::payload_short is set, and to conv_encode() otherwise.
 */
vector<int>
code_encode (ConvBlockType block_type, const vector<int>& in_bits)
{
  if (Params::payload_short)
    return short_encode (block_type, in_bits);

  return conv_encode (block_type, in_bits);
}

/*
 * Size query matching code_encode(): dispatches to short_code_size() when
 * Params::payload_short is set, and to conv_code_size() otherwise.
 */
size_t
code_size (ConvBlockType block_type, size_t msg_size)
{
  if (Params::payload_short)
    return short_code_size (block_type, msg_size);

  return conv_code_size (block_type, msg_size);
}

/*
 * Soft-decode coded_bits for the given block type: dispatches to
 * short_decode_soft() when Params::payload_short is set, and to
 * conv_decode_soft() otherwise.  error_out is passed through unchanged to the
 * selected decoder.
 */
vector<int>
code_decode_soft (ConvBlockType block_type, const std::vector<float>& coded_bits, float *error_out)
{
  if (Params::payload_short)
    return short_decode_soft (block_type, coded_bits, error_out);

  return conv_decode_soft (block_type, coded_bits, error_out);
}

/*
 * Block-encode in_bits with the currently selected generator matrix.
 *
 * in_bits must contain exactly gen_in_count entries (asserted); any non-zero
 * entry counts as a set bit.  The result has gen_out_count entries and is the
 * XOR (sum over GF(2)) of the generator matrix rows whose message bit is set.
 */
vector<int>
short_encode_blk (const vector<int>& in_bits)
{
  assert (gen_matrix.size() == in_bits.size());
  assert (gen_matrix.size() == gen_in_count);
  assert (gen_matrix[0].size() == gen_out_count);

  /* start from the all-zero codeword and fold in one row per set message bit */
  vector<int> codeword (gen_out_count, 0);
  for (size_t row = 0; row < gen_in_count; row++)
    {
      if (!in_bits[row])
        continue;

      const vector<int>& generator_row = gen_matrix[row];
      for (size_t col = 0; col < gen_out_count; col++)
        codeword[col] ^= generator_row[col];
    }
  return codeword;
}

/*
 * Short-payload encoder: first applies the block code to in_bits, then runs
 * the resulting bits through conv_encode() for the given block type.
 */
vector<int>
short_encode (ConvBlockType block_type, const vector<int>& in_bits)
{
  const vector<int> block_coded = short_encode_blk (in_bits);

  return conv_encode (block_type, block_coded);
}

/*
 * Number of coded bits short_encode() produces for the given block type.
 *
 * msg_size must equal the number of rows of the selected generator matrix
 * (asserted); the result is conv_code_size() applied to the block code's
 * output length, gen_out_count.
 */
size_t
short_code_size (ConvBlockType block_type, size_t msg_size)
{
  assert (gen_matrix.size() == msg_size);

  const size_t coded_size = conv_code_size (block_type, gen_out_count);
  return coded_size;
}

/*
 * Decode one block by exhaustive search over all 2^gen_in_count messages.
 *
 * Every candidate message c is encoded with the selected generator matrix
 * (bit i of c selects row i) and compared with the first gen_out_count entries
 * of coded_bits.  Only an exact match is accepted; no error correction is
 * performed here.
 *
 * Returns the gen_in_count message bits (entry i = bit i of the matching
 * candidate), or an empty vector if no codeword matches or if coded_bits
 * holds fewer than gen_out_count entries.
 */
vector<int>
short_decode_blk (const vector<int>& coded_bits)
{
  vector<int> out_bits;

  /* A truncated input can never equal a full codeword.  Report it the same
   * way as "no match" instead of indexing past the end of coded_bits. */
  if (coded_bits.size() < gen_out_count)
    return out_bits;

  /* do the shifts on size_t so they do not depend on the width of int */
  const size_t n_candidates = size_t (1) << gen_in_count;

  for (size_t c = 0; c < n_candidates; c++)
    {
      bool match = true;
      for (size_t j = 0; match && j < gen_out_count; j++)
        {
          /* coded bit j of candidate c: XOR of column j over the selected rows */
          int x = 0;
          for (size_t bit = 0; bit < gen_in_count; bit++)
            {
              if (c & (size_t (1) << bit))
                x ^= gen_matrix[bit][j];
            }
          /* the first mismatch rules this candidate out and ends the inner loop */
          match = (coded_bits[j] == x);
        }

      if (match)
        {
          out_bits.reserve (gen_in_count);
          for (size_t bit = 0; bit < gen_in_count; bit++)
            out_bits.push_back (static_cast<int> ((c >> bit) & 1));
          return out_bits;
        }
    }

  /* no codeword matches: empty result signals the decode failure */
  return out_bits;
}

/*
 * Short-payload decoder: runs conv_decode_soft() on coded_bits (error_out is
 * handed through to it) and then block-decodes the resulting bits with
 * short_decode_blk().
 */
vector<int>
short_decode_soft (ConvBlockType block_type, const std::vector<float>& coded_bits, float *error_out)
{
  const vector<int> block_coded = conv_decode_soft (block_type, coded_bits, error_out);

  return short_decode_blk (block_coded);
}
